{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "google",
   "metadata": {},
   "source": [
    "##### Copyright 2023 Google LLC."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "apache",
   "metadata": {},
   "source": [
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "basename",
   "metadata": {},
   "source": [
    "# scheduling_with_transitions_sat"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "link",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "<td>\n",
    "<a href=\"https://colab.research.google.com/github/google/or-tools/blob/main/examples/notebook/contrib/scheduling_with_transitions_sat.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\"/>Run in Google Colab</a>\n",
    "</td>\n",
    "<td>\n",
    "<a href=\"https://github.com/google/or-tools/blob/main/examples/contrib/scheduling_with_transitions_sat.py\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\"/>View source on GitHub</a>\n",
    "</td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "doc",
   "metadata": {},
   "source": [
    "First, you must install the [ortools](https://pypi.org/project/ortools/) package in this colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "install",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install ortools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "description",
   "metadata": {},
   "source": [
    "Scheduling problem with transition times between tasks and transition costs.\n",
    "\n",
    "@author: CSLiu2\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import collections\n",
    "\n",
    "from ortools.sat.python import cp_model\n",
    "from google.protobuf import text_format\n",
    "\n",
    "#----------------------------------------------------------------------------\n",
    "# Command line arguments.\n",
    "PARSER = argparse.ArgumentParser()\n",
    "PARSER.add_argument(\n",
    "    '--problem_instance', default=0, type=int, help='Problem instance.')\n",
    "PARSER.add_argument(\n",
    "    '--output_proto',\n",
    "    default='',\n",
    "    help='Output file to write the cp_model proto to.')\n",
    "PARSER.add_argument('--params', default='', help='Sat solver parameters.')\n",
    "\n",
    "\n",
    "#----------------------------------------------------------------------------\n",
    "# Intermediate solution printer\n",
    "class SolutionPrinter(cp_model.CpSolverSolutionCallback):\n",
    "  \"\"\"Callback that prints one summary line per intermediate solution.\"\"\"\n",
    "\n",
    "  def __init__(self, makespan):\n",
    "    \"\"\"Keeps the makespan variable and starts the solution counter at 0.\"\"\"\n",
    "    super().__init__()\n",
    "    self.__makespan = makespan\n",
    "    self.__solution_count = 0\n",
    "\n",
    "  def OnSolutionCallback(self):\n",
    "    \"\"\"Prints index, wall time, objective and makespan, then counts up.\"\"\"\n",
    "    template = 'Solution %i, time = %f s, objective = %i, makespan = %i'\n",
    "    values = (self.__solution_count, self.WallTime(), self.ObjectiveValue(),\n",
    "              self.Value(self.__makespan))\n",
    "    print(template % values)\n",
    "    self.__solution_count += 1\n",
    "\n",
    "\n",
    "def main(args):\n",
    "  \"\"\"Solves the scheduling with transitions problem.\n",
    "\n",
    "  Each job is a sequence of tasks and each task lists its alternatives as\n",
    "  `(duration, machine_id, resource_type)` tuples; a duration of -1 marks an\n",
    "  alternative that cannot be used. Tasks selected on the same machine are\n",
    "  sequenced with a circuit constraint, and a transition time of 3 (plus a\n",
    "  transition cost in the objective) applies between consecutive tasks whose\n",
    "  resource types differ. The objective is the weighted sum of the makespan\n",
    "  and the number of such transitions.\n",
    "\n",
    "  Args:\n",
    "    args: Namespace with `problem_instance` (0 selects the small data set,\n",
    "      any other value the large one), `params` (text-format solver\n",
    "      parameters, may be empty) and `output_proto` (path the model is\n",
    "      written to, may be empty).\n",
    "  \"\"\"\n",
    "\n",
    "  instance = args.problem_instance\n",
    "  parameters = args.params\n",
    "  output_proto = args.output_proto\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Data.\n",
    "  small_jobs = [[[(100, 0, 'R6'), (2, 1, 'R6')]],\n",
    "                [[(2, 0, 'R3'), (100, 1, 'R3')]],\n",
    "                [[(100, 0, 'R1'), (16, 1, 'R1')]],\n",
    "                [[(1, 0, 'R1'), (38, 1, 'R1')]], [[(14, 0, 'R1'), (10, 1,\n",
    "                                                                   'R1')]],\n",
    "                [[(16, 0, 'R3'), (17, 1, 'R3')]],\n",
    "                [[(14, 0, 'R3'), (14, 1, 'R3')]],\n",
    "                [[(14, 0, 'R3'), (15, 1, 'R3')]],\n",
    "                [[(14, 0, 'R3'), (13, 1, 'R3')]],\n",
    "                [[(100, 0, 'R1'), (38, 1, 'R1')]]]\n",
    "\n",
    "  large_jobs = [\n",
    "      [[(-1, 0, 'R1'), (10, 1, 'R1')]], [[(9, 0, 'R3'),\n",
    "                                          (22, 1, 'R3')]],\n",
    "      [[(-1, 0, 'R3'), (13, 1, 'R3')]], [[(-1, 0, 'R3'), (38, 1, 'R3')]],\n",
    "      [[(-1, 0, 'R3'), (38, 1, 'R3')]], [[(-1, 0, 'R3'), (16, 1, 'R3')]],\n",
    "      [[(-1, 0, 'R3'), (11, 1, 'R3')]], [[(-1, 0, 'R3'), (13, 1, 'R3')]],\n",
    "      [[(13, 0, 'R3'), (-1, 1, 'R3')]], [[(13, 0, 'R3'), (-1, 1, 'R3')]],\n",
    "      [[(-1, 0, 'R3'), (15, 1, 'R3')]], [[(12, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(14, 0, 'R1'), (-1, 1, 'R1')]], [[(13, 0, 'R3'), (-1, 1, 'R3')]],\n",
    "      [[(-1, 0, 'R3'), (15, 1, 'R3')]], [[(14, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(13, 0, 'R3'), (-1, 1, 'R3')]], [[(14, 0, 'R3'), (-1, 1, 'R3')]],\n",
    "      [[(13, 0, 'R1'), (-1, 1, 'R1')]], [[(15, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(-1, 0, 'R2'), (16, 1, 'R2')]], [[(-1, 0, 'R2'), (16, 1, 'R2')]],\n",
    "      [[(-1, 0, 'R5'), (27, 1, 'R5')]], [[(-1, 0, 'R5'), (13, 1, 'R5')]],\n",
    "      [[(-1, 0, 'R5'), (13, 1, 'R5')]], [[(-1, 0, 'R5'), (13, 1, 'R5')]],\n",
    "      [[(13, 0, 'R1'), (-1, 1, 'R1')]], [[(-1, 0, 'R1'), (17, 1, 'R1')]],\n",
    "      [[(14, 0, 'R4'), (-1, 1, 'R4')]], [[(13, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(16, 0, 'R4'), (-1, 1, 'R4')]], [[(16, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(16, 0, 'R4'), (-1, 1, 'R4')]], [[(16, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(13, 0, 'R1'), (-1, 1, 'R1')]], [[(13, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(14, 0, 'R4'), (-1, 1, 'R4')]], [[(13, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(12, 0, 'R1'), (-1, 1, 'R1')]], [[(14, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(-1, 0, 'R5'), (14, 1, 'R5')]], [[(14, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(14, 0, 'R4'), (-1, 1, 'R4')]], [[(14, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(14, 0, 'R4'), (-1, 1, 'R4')]], [[(-1, 0, 'R1'), (21, 1, 'R1')]],\n",
    "      [[(-1, 0, 'R1'), (21, 1, 'R1')]], [[(-1, 0, 'R1'), (21, 1, 'R1')]],\n",
    "      [[(13, 0, 'R6'), (-1, 1, 'R6')]], [[(13, 0, 'R2'), (-1, 1, 'R2')]],\n",
    "      [[(-1, 0, 'R6'), (12, 1, 'R6')]], [[(13, 0, 'R1'), (-1, 1, 'R1')]],\n",
    "      [[(13, 0, 'R1'), (-1, 1, 'R1')]], [[(-1, 0, 'R6'), (14, 1, 'R6')]],\n",
    "      [[(-1, 0, 'R5'), (5, 1, 'R5')]], [[(3, 0, 'R2'), (-1, 1, 'R2')]],\n",
    "      [[(-1, 0, 'R1'), (21, 1, 'R1')]], [[(-1, 0, 'R1'), (21, 1, 'R1')]],\n",
    "      [[(-1, 0, 'R1'), (21, 1, 'R1')]], [[(-1, 0, 'R5'), (1, 1, 'R5')]],\n",
    "      [[(1, 0, 'R2'), (-1, 1, 'R2')]], [[(-1, 0, 'R2'), (19, 1, 'R2')]],\n",
    "      [[(13, 0, 'R4'), (-1, 1, 'R4')]], [[(12, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(-1, 0, 'R3'), (2, 1, 'R3')]], [[(11, 0, 'R4'), (-1, 1, 'R4')]],\n",
    "      [[(6, 0, 'R6'), (-1, 1, 'R6')]], [[(6, 0, 'R6'), (-1, 1, 'R6')]],\n",
    "      [[(1, 0, 'R2'), (-1, 1, 'R2')]], [[(12, 0, 'R4'), (-1, 1, 'R4')]]\n",
    "  ]\n",
    "\n",
    "  jobs = small_jobs if instance == 0 else large_jobs\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Helper data.\n",
    "  num_jobs = len(jobs)\n",
    "  all_jobs = range(num_jobs)\n",
    "  num_machines = 2\n",
    "  all_machines = range(num_machines)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Model.\n",
    "  model = cp_model.CpModel()\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Compute a maximum makespan greedily: the sum, over all tasks, of the\n",
    "  # longest alternative. Starting the max at 0 ignores the -1 sentinels.\n",
    "  # NOTE(review): transition times are not part of this bound -- confirm that\n",
    "  # is acceptable for data where one machine receives most of the tasks.\n",
    "  horizon = 0\n",
    "  for job in jobs:\n",
    "    for task in job:\n",
    "      max_task_duration = 0\n",
    "      for alternative in task:\n",
    "        max_task_duration = max(max_task_duration, alternative[0])\n",
    "      horizon += max_task_duration\n",
    "\n",
    "  print('Horizon = %i' % horizon)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Global storage of variables. The per-machine lists are filled in lockstep,\n",
    "  # so index i refers to the same alternative in each of them.\n",
    "  intervals_per_machines = collections.defaultdict(list)\n",
    "  presences_per_machines = collections.defaultdict(list)\n",
    "  starts_per_machines = collections.defaultdict(list)\n",
    "  ends_per_machines = collections.defaultdict(list)\n",
    "  resources_per_machines = collections.defaultdict(list)\n",
    "  ranks_per_machines = collections.defaultdict(list)\n",
    "\n",
    "  job_starts = {}  # indexed by (job_id, task_id).\n",
    "  job_presences = {}  # indexed by (job_id, task_id, alt_id).\n",
    "  job_ranks = {}  # indexed by (job_id, task_id, alt_id).\n",
    "  job_ends = []  # indexed by job_id\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Scan the jobs and create the relevant variables and intervals.\n",
    "  for job_id in all_jobs:\n",
    "    job = jobs[job_id]\n",
    "    num_tasks = len(job)\n",
    "    previous_end = None\n",
    "    for task_id in range(num_tasks):\n",
    "      task = job[task_id]\n",
    "\n",
    "      # Duration bounds over all listed alternatives. A -1 sentinel can lower\n",
    "      # `min_duration`; the duration is still tied to the selected\n",
    "      # alternative's duration by the linking constraints below.\n",
    "      min_duration = task[0][0]\n",
    "      max_duration = task[0][0]\n",
    "\n",
    "      num_alternatives = len(task)\n",
    "      all_alternatives = range(num_alternatives)\n",
    "\n",
    "      for alt_id in range(1, num_alternatives):\n",
    "        alt_duration = task[alt_id][0]\n",
    "        min_duration = min(min_duration, alt_duration)\n",
    "        max_duration = max(max_duration, alt_duration)\n",
    "\n",
    "      # Create main interval for the task.\n",
    "      suffix_name = '_j%i_t%i' % (job_id, task_id)\n",
    "      start = model.NewIntVar(0, horizon, 'start' + suffix_name)\n",
    "      duration = model.NewIntVar(min_duration, max_duration,\n",
    "                                 'duration' + suffix_name)\n",
    "      end = model.NewIntVar(0, horizon, 'end' + suffix_name)\n",
    "\n",
    "      # Store the start for the solution.\n",
    "      job_starts[(job_id, task_id)] = start\n",
    "\n",
    "      # Add precedence with previous task in the same job. Compare against\n",
    "      # None explicitly instead of relying on the truth value of a solver\n",
    "      # variable object.\n",
    "      if previous_end is not None:\n",
    "        model.Add(start >= previous_end)\n",
    "      previous_end = end\n",
    "\n",
    "      # Create alternative intervals.\n",
    "      l_presences = []\n",
    "      for alt_id in all_alternatives:\n",
    "        # A process time of -1 means the task cannot be processed on this\n",
    "        # machine, so no variables are created for that alternative.\n",
    "        if task[alt_id][0] == -1:\n",
    "          continue\n",
    "        alt_suffix = '_j%i_t%i_a%i' % (job_id, task_id, alt_id)\n",
    "        l_presence = model.NewBoolVar('presence' + alt_suffix)\n",
    "        l_start = model.NewIntVar(0, horizon, 'start' + alt_suffix)\n",
    "        l_duration = task[alt_id][0]\n",
    "        l_end = model.NewIntVar(0, horizon, 'end' + alt_suffix)\n",
    "        l_interval = model.NewOptionalIntervalVar(\n",
    "            l_start, l_duration, l_end, l_presence, 'interval' + alt_suffix)\n",
    "        # Rank -1 is reserved for alternatives that are not selected.\n",
    "        l_rank = model.NewIntVar(-1, num_jobs, 'rank' + alt_suffix)\n",
    "        l_presences.append(l_presence)\n",
    "        l_machine = task[alt_id][1]\n",
    "        l_type = task[alt_id][2]\n",
    "\n",
    "        # Link the original variables with the local ones.\n",
    "        model.Add(start == l_start).OnlyEnforceIf(l_presence)\n",
    "        model.Add(duration == l_duration).OnlyEnforceIf(l_presence)\n",
    "        model.Add(end == l_end).OnlyEnforceIf(l_presence)\n",
    "\n",
    "        # Add the local variables to the right machine.\n",
    "        intervals_per_machines[l_machine].append(l_interval)\n",
    "        starts_per_machines[l_machine].append(l_start)\n",
    "        ends_per_machines[l_machine].append(l_end)\n",
    "        presences_per_machines[l_machine].append(l_presence)\n",
    "        resources_per_machines[l_machine].append(l_type)\n",
    "        ranks_per_machines[l_machine].append(l_rank)\n",
    "\n",
    "        # Store the variables for the solution.\n",
    "        job_presences[(job_id, task_id, alt_id)] = l_presence\n",
    "        job_ranks[(job_id, task_id, alt_id)] = l_rank\n",
    "\n",
    "      # Only one machine can process each lot.\n",
    "      model.Add(sum(l_presences) == 1)\n",
    "\n",
    "    job_ends.append(previous_end)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Create machines constraints nonoverlap process\n",
    "  for machine_id in all_machines:\n",
    "    intervals = intervals_per_machines[machine_id]\n",
    "    if len(intervals) > 1:\n",
    "      model.AddNoOverlap(intervals)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Transition times and transition costs using a circuit constraint.\n",
    "  # Node 0 is a dummy node; alternative i of the machine is node i + 1.\n",
    "  switch_literals = []\n",
    "  for machine_id in all_machines:\n",
    "    machine_starts = starts_per_machines[machine_id]\n",
    "    machine_ends = ends_per_machines[machine_id]\n",
    "    machine_presences = presences_per_machines[machine_id]\n",
    "    machine_resources = resources_per_machines[machine_id]\n",
    "    machine_ranks = ranks_per_machines[machine_id]\n",
    "    arcs = []\n",
    "    num_machine_tasks = len(machine_starts)\n",
    "    all_machine_tasks = range(num_machine_tasks)\n",
    "\n",
    "    for i in all_machine_tasks:\n",
    "      # Initial arc from the dummy node (0) to a task.\n",
    "      start_lit = model.NewBoolVar('')\n",
    "      arcs.append([0, i + 1, start_lit])\n",
    "      # If this task is the first, set both rank and start to 0.\n",
    "      model.Add(machine_ranks[i] == 0).OnlyEnforceIf(start_lit)\n",
    "      model.Add(machine_starts[i] == 0).OnlyEnforceIf(start_lit)\n",
    "      # Final arc from a task back to the dummy node.\n",
    "      arcs.append([i + 1, 0, model.NewBoolVar('')])\n",
    "      # Self arc if the task is not performed.\n",
    "      arcs.append([i + 1, i + 1, machine_presences[i].Not()])\n",
    "      model.Add(machine_ranks[i] == -1).OnlyEnforceIf(\n",
    "          machine_presences[i].Not())\n",
    "\n",
    "      for j in all_machine_tasks:\n",
    "        if i == j:\n",
    "          continue\n",
    "\n",
    "        lit = model.NewBoolVar('%i follows %i' % (j, i))\n",
    "        arcs.append([i + 1, j + 1, lit])\n",
    "        model.AddImplication(lit, machine_presences[i])\n",
    "        model.AddImplication(lit, machine_presences[j])\n",
    "\n",
    "        # Maintain rank incrementally.\n",
    "        model.Add(machine_ranks[j] == machine_ranks[i] + 1).OnlyEnforceIf(lit)\n",
    "\n",
    "        # Compute the transition time if task j is the successor of task i.\n",
    "        # A change of resource type costs time and is counted in the\n",
    "        # objective through `switch_literals`.\n",
    "        if machine_resources[i] != machine_resources[j]:\n",
    "          transition_time = 3\n",
    "          switch_literals.append(lit)\n",
    "        else:\n",
    "          transition_time = 0\n",
    "        # We add the reified transition to link the literals with the times\n",
    "        # of the tasks.\n",
    "        model.Add(machine_starts[j] == machine_ends[i] +\n",
    "                  transition_time).OnlyEnforceIf(lit)\n",
    "    # Skip machines that received no alternative at all.\n",
    "    if arcs:\n",
    "      model.AddCircuit(arcs)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Objective.\n",
    "  makespan = model.NewIntVar(0, horizon, 'makespan')\n",
    "  model.AddMaxEquality(makespan, job_ends)\n",
    "  makespan_weight = 1\n",
    "  transition_weight = 5\n",
    "  model.Minimize(makespan * makespan_weight +\n",
    "                 sum(switch_literals) * transition_weight)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Write problem to file.\n",
    "  if output_proto:\n",
    "    print('Writing proto to %s' % output_proto)\n",
    "    with open(output_proto, 'w') as text_file:\n",
    "      text_file.write(str(model))\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Solve.\n",
    "  solver = cp_model.CpSolver()\n",
    "  solver.parameters.max_time_in_seconds = 60 * 60 * 2\n",
    "  if parameters:\n",
    "    # User supplied parameters are merged on top of the default time limit.\n",
    "    text_format.Merge(parameters, solver.parameters)\n",
    "  solution_printer = SolutionPrinter(makespan)\n",
    "  status = solver.Solve(model, solution_printer)\n",
    "\n",
    "  #----------------------------------------------------------------------------\n",
    "  # Print solution.\n",
    "  if status == cp_model.FEASIBLE or status == cp_model.OPTIMAL:\n",
    "    for job_id in all_jobs:\n",
    "      for task_id in range(len(jobs[job_id])):\n",
    "        start_value = solver.Value(job_starts[(job_id, task_id)])\n",
    "        machine = 0\n",
    "        duration = 0\n",
    "        select = 0\n",
    "        rank = -1\n",
    "\n",
    "        for alt_id in range(len(jobs[job_id][task_id])):\n",
    "          # No variables exist for alternatives flagged with -1.\n",
    "          if jobs[job_id][task_id][alt_id][0] == -1:\n",
    "            continue\n",
    "\n",
    "          if solver.BooleanValue(job_presences[(job_id, task_id, alt_id)]):\n",
    "            duration = jobs[job_id][task_id][alt_id][0]\n",
    "            machine = jobs[job_id][task_id][alt_id][1]\n",
    "            select = alt_id\n",
    "            rank = solver.Value(job_ranks[(job_id, task_id, alt_id)])\n",
    "\n",
    "        print(\n",
    "            '  Job %i starts at %i (alt %i, duration %i) with rank %i on machine %i'\n",
    "            % (job_id, start_value, select, duration, rank, machine))\n",
    "\n",
    "    print('Solve status: %s' % solver.StatusName(status))\n",
    "    print('Objective value: %i' % solver.ObjectiveValue())\n",
    "    print('Makespan: %i' % solver.Value(makespan))\n",
    "  else:\n",
    "    # Report the outcome instead of ending silently when nothing was found.\n",
    "    print('No solution found. Solve status: %s' % solver.StatusName(status))\n",
    "\n",
    "\n",
    "# Notebook kernels put their own flags (e.g. `-f <connection file>`) in\n",
    "# sys.argv; parse_known_args() ignores them where parse_args() would exit\n",
    "# with an 'unrecognized arguments' error.\n",
    "main(PARSER.parse_known_args()[0])\n",
    "\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
